# 虽然默认的即时执行模式（Eager Execution）为我们带来了灵活及易调试的特性，但在特定的场合，例如追求高性能或部署模型时，我们依然希望使用
# TensorFlow 1.X 中默认的图执行模式（Graph Execution），将模型转换为高效的 TensorFlow 图模型。
# 此时，TensorFlow 2 为我们提供了 tf.function 模块，结合 AutoGraph 机制，
# 使得我们仅需加入一个简单的 @tf.function 修饰符，就能轻松将模型以图执行模式运行。

import tensorflow as tf
import time
from zh.model.mnist.cnn import CNN
from zh.model.utils import MNISTLoader

num_batches = 1000
batch_size = 50
learning_rate = 0.001
data_loader = MNISTLoader()
model = CNN()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)


@tf.function
def train_one_step(X, y):
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        # 注意这里使用了TensorFlow内置的tf.print()。@tf.function不支持Python内置的print方法
        tf.print("loss", loss)
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))


start_time = time.time()
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    train_one_step(X, y)
end_time = time.time()
print(end_time - start_time)

# 运行 400 个 Batch 进行测试，
# 加入 @tf.function 的程序耗时 35.5 秒，
# 未加入 @tf.function 的纯即时执行模式程序耗时 43.8 秒。
# 可见 @tf.function 带来了一定的性能提升。
# 一般而言，当模型由较多小的操作组成的时候， @tf.function 带来的提升效果较大。
# 而当模型的操作数量较少，但单一操作均很耗时的时候，则 @tf.function 带来的性能提升不会太大
